{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Note\n",
    "\n",
    "View the README.md [here](https://github.com/eclipse/deeplearning4j-examples/blob/master/tutorials/README.md) to learn about installing, setting up dependencies and importing notebooks in Zeppelin"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Background\n",
    "\n",
    "Recurrent neural networks (RNNs) are used when the input is sequential in nature. RNNs are typically much more effective than regular feed-forward neural networks on sequential data because they can keep track of dependencies in the data over multiple time steps. This is possible because the output of an RNN at a given time step depends on both the current input and the output of the previous time step.\n",
    "\n",
    "RNNs can also be applied to situations where the input is sequential but the output isn't. In these cases the output of the last time step of the RNN is typically taken as the output for the overall observation. For classification, the output of the last time step will be the predicted class label for the observation.\n",
    "\n",
    "In this notebook we will show how to build an RNN using the MultiLayerNetwork class of deeplearning4j (DL4J). This tutorial will focus on applying an RNN to a classification task. We will be using the UCI synthetic control chart dataset, which consists of 600 univariate time series of 60 time steps each, evenly divided among six classes (normal, cyclic, increasing trend, decreasing trend, upward shift and downward shift). Each observation will therefore be interpreted as a sequence of 60 time steps with one scalar value per step.\n",
    "\n"
   ]
  },
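  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The recurrence described above can be sketched in plain Scala. This is a minimal illustration, not DL4J's implementation: it uses a single hidden unit, and the weights `wX`, `wH` and the bias `b` are arbitrary values chosen for the example rather than learned parameters.\n",
    "\n",
    "```scala\n",
    "// Hypothetical scalar RNN cell: h_t = tanh(wX * x_t + wH * h_(t-1) + b)\n",
    "def rnnStep(x: Double, hPrev: Double, wX: Double, wH: Double, b: Double): Double =\n",
    "  math.tanh(wX * x + wH * hPrev + b)\n",
    "\n",
    "// Run the cell over a whole sequence, keeping the hidden state of every time step\n",
    "def runRnn(xs: Seq[Double], wX: Double, wH: Double, b: Double): Seq[Double] =\n",
    "  xs.scanLeft(0.0)((hPrev, x) => rnnStep(x, hPrev, wX, wH, b)).tail\n",
    "\n",
    "val hiddenStates = runRnn(Seq(0.5, -1.0, 2.0), wX = 0.8, wH = 0.3, b = 0.1)\n",
    "\n",
    "// For sequence classification, the state at the last time step summarizes the whole sequence\n",
    "val lastState = hiddenStates.last\n",
    "println(hiddenStates.size)  // one hidden state per input time step\n",
    "```\n",
    "\n",
    "Because each state feeds into the next step, the last state depends on every earlier input, which is why it can stand in for the whole sequence."
   ]
  },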
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "autoscroll": "auto"
   },
   "outputs": [],
   "source": [
    "import org.deeplearning4j.eval.Evaluation\n",
    "import org.deeplearning4j.nn.api.OptimizationAlgorithm\n",
    "import org.deeplearning4j.nn.conf.GradientNormalization\n",
    "import org.deeplearning4j.nn.conf.MultiLayerConfiguration\n",
    "import org.deeplearning4j.nn.conf.NeuralNetConfiguration\n",
    "import org.deeplearning4j.nn.conf.Updater\n",
    "import org.deeplearning4j.nn.multilayer.MultiLayerNetwork\n",
    "import org.deeplearning4j.nn.weights.WeightInit\n",
    "import org.deeplearning4j.nn.conf.layers.{DenseLayer, GravesLSTM, OutputLayer, RnnOutputLayer}\n",
    "import org.deeplearning4j.nn.conf.distribution.UniformDistribution\n",
    "import org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator\n",
    "import org.deeplearning4j.optimize.listeners.ScoreIterationListener\n",
    "\n",
    "import org.datavec.api.split.NumberedFileInputSplit\n",
    "import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader\n",
    "\n",
    "import org.nd4j.linalg.dataset.DataSet\n",
    "import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction\n",
    "import org.nd4j.linalg.api.ndarray.INDArray\n",
    "import org.nd4j.linalg.activations.Activation\n",
    "import org.nd4j.linalg.dataset.api.iterator.DataSetIterator\n",
    "\n",
    "import org.slf4j.Logger\n",
    "import org.slf4j.LoggerFactory\n",
    "import org.apache.commons.io.{FileUtils, IOUtils}\n",
    "\n",
    "import java.io.File\n",
    "import java.nio.charset.Charset\n",
    "import java.util.Random\n",
    "import java.net.URL"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Download the dataset\n",
    "\n",
    "The UCI Machine Learning Repository hosts a number of datasets for machine learning. Make sure you have enough space on your local disk before downloading. The UCI synthetic control dataset can be found at [http://archive.ics.uci.edu/ml/datasets/synthetic+control+chart+time+series](http://archive.ics.uci.edu/ml/datasets/synthetic+control+chart+time+series). The code below checks whether the data already exists locally and downloads it if it does not. Each sequence is then written to its own CSV file, with one `value, label` pair per line (one line per time step)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "autoscroll": "auto"
   },
   "outputs": [],
   "source": [
    "// Assumption: the data is cached under the user's home directory; change this path if needed\n",
    "val cache = new File(System.getProperty(\"user.home\"), \".deeplearning4j\")\n",
    "val dataPath = new File(cache, \"uci_synthetic_control\")\n",
    "\n",
    "if (!dataPath.exists()) {\n",
    "    val url = \"https://archive.ics.uci.edu/ml/machine-learning-databases/synthetic_control-mld/synthetic_control.data\"\n",
    "    println(\"Downloading file...\")\n",
    "    val data = IOUtils.toString(new URL(url), Charset.defaultCharset())\n",
    "    val lines = data.split(\"\\n\").map(_.trim).filter(_.nonEmpty)\n",
    "\n",
    "    println(\"Extracting file...\")\n",
    "\n",
    "    // Each line holds one sequence; every block of 100 consecutive lines belongs to one class (0 to 5)\n",
    "    val sequences = lines.zipWithIndex.map { case (line, lineCount) =>\n",
    "        val label = lineCount / 100\n",
    "        // Write one \"value, label\" pair per line, i.e. one line per time step\n",
    "        line.split(\"\\\\s+\").map(value => value + \", \" + label).mkString(\"\\n\")\n",
    "    }\n",
    "\n",
    "    // Shuffle so that the training and test files both contain a mix of classes\n",
    "    val shuffled = scala.util.Random.shuffle(sequences.toList)\n",
    "\n",
    "    // Save each sequence to its own numbered file: 0.csv, 1.csv, ...\n",
    "    for ((sequence, index) <- shuffled.zipWithIndex) {\n",
    "        val outPath = new File(dataPath, index + \".csv\")\n",
    "        FileUtils.writeStringToFile(outPath, sequence, Charset.defaultCharset())\n",
    "    }\n",
    "    println(\"Done.\")\n",
    "} else {\n",
    "    println(\"File already exists.\")\n",
    "}"
   ]
  },
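  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the file format concrete, the sketch below converts one whitespace-separated row into the `value, label` layout that the sequence reader expects. The helper name `toCsvSequence` and the shortened three-value row are made up for illustration; real rows in the dataset contain 60 values.\n",
    "\n",
    "```scala\n",
    "// Hypothetical helper: turn one whitespace-separated row into one \"value, label\" line per time step\n",
    "def toCsvSequence(row: String, label: Int): String =\n",
    "  row.trim.split(\"\\\\s+\").map(value => s\"$value, $label\").mkString(\"\\n\")\n",
    "\n",
    "// A shortened example row (real rows hold 60 values)\n",
    "val csv = toCsvSequence(\"28.78 34.46 31.33\", label = 0)\n",
    "println(csv)\n",
    "// 28.78, 0\n",
    "// 34.46, 0\n",
    "// 31.33, 0\n",
    "```\n",
    "\n",
    "Repeating the label on every line means each time step has its own target, which is the layout a single-reader sequence iterator can consume directly."
   ]
  },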
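  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Iterating from disk\n",
    "\n",
    "Now that we've saved our dataset in a CSV sequence format, we need to set up a `CSVSequenceRecordReader` and an iterator that will read our saved sequences and feed them to our network. The reader skips 0 header lines and splits each line on `, `; the iterator is told that there are 6 possible classes and that the label sits in column 1. Files 0 to 449 are used for training and files 450 to 599 for testing. If you have already saved your data to disk, you can run this code block (and the remaining code blocks) as many times as you want without preprocessing the dataset again. Convenient!"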
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "autoscroll": "auto"
   },
   "outputs": [],
   "source": [
    "val batchSize = 128\n",
    "val numLabelClasses = 6\n",
    "\n",
    "// training data\n",
    "val trainRR = new CSVSequenceRecordReader(0, \", \")\n",
    "trainRR.initialize(new NumberedFileInputSplit(dataPath.getAbsolutePath() + \"/%d.csv\", 0, 449))\n",
    "val trainIter = new SequenceRecordReaderDataSetIterator(trainRR, batchSize, numLabelClasses, 1)\n",
    "\n",
    "// testing data\n",
    "val testRR = new CSVSequenceRecordReader(0, \", \")\n",
    "testRR.initialize(new NumberedFileInputSplit(dataPath.getAbsolutePath() + \"/%d.csv\", 450, 599))\n",
    "val testIter = new SequenceRecordReaderDataSetIterator(testRR, batchSize, numLabelClasses, 1)"
   ]
  },
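  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check on the split, the arithmetic below works out how many minibatches each iterator produces per pass. It is plain Scala and does not touch the DL4J iterators; the helper name `numBatches` is made up for this example, and it assumes the final partial batch is still returned rather than dropped.\n",
    "\n",
    "```scala\n",
    "// Hypothetical helper: number of minibatches needed to cover numExamples (ceiling division)\n",
    "def numBatches(numExamples: Int, batchSize: Int): Int =\n",
    "  (numExamples + batchSize - 1) / batchSize\n",
    "\n",
    "val trainExamples = 449 - 0 + 1     // files 0.csv to 449.csv\n",
    "val testExamples  = 599 - 450 + 1   // files 450.csv to 599.csv\n",
    "\n",
    "println(numBatches(trainExamples, 128))  // 4\n",
    "println(numBatches(testExamples, 128))   // 2\n",
    "```\n",
    "\n",
    "Under that assumption, one epoch amounts to only four parameter updates, so a single pass over the data is unlikely to be enough for a well-trained model."
   ]
  },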
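  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Configuring an RNN for Classification\n",
    "Once everything needed is imported and the data iterators are set up, we can configure the network using a setup like the one shown below. Because each time step carries a single scalar value, `nIn` of the LSTM layer is set to 1. The LSTM layer has 10 hidden units (`nOut` is 10), and the output layer maps those 10 units to the 6 class labels (`numLabelClasses`). Since every time step in our files carries the label of its sequence, the `RnnOutputLayer` predicts a class at every time step.\n"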
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "autoscroll": "auto"
   },
   "outputs": [],
   "source": [
    "val conf = new NeuralNetConfiguration.Builder()\n",
    "    .seed(123)    //Random number generator seed for improved repeatability. Optional.\n",
    "    .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)\n",
    "    .iterations(1)\n",
    "    .weightInit(WeightInit.XAVIER)\n",
    "    .updater(Updater.NESTEROVS)\n",
    "    .learningRate(0.005)\n",
    "    .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)  //Not always required, but helps with this data set\n",
    "    .gradientNormalizationThreshold(0.5)\n",
    "    .list()\n",
    "    .layer(0, new GravesLSTM.Builder().activation(Activation.TANH).nIn(1).nOut(10).build())\n",
    "    .layer(1, new RnnOutputLayer.Builder(LossFunction.MCXENT)\n",
    "            .activation(Activation.SOFTMAX).nIn(10).nOut(numLabelClasses).build())\n",
    "    .pretrain(false).backprop(true).build()\n",
    "\n",
    "val model: MultiLayerNetwork = new MultiLayerNetwork(conf)\n",
    "model.init()    //Initialize the network parameters before training\n",
    "model.setListeners(new ScoreIterationListener(20))    //Print the score every 20 iterations"
   ]
  },
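  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `ClipElementWiseAbsoluteValue` setting used above limits each gradient element independently to the range [-threshold, threshold]. The sketch below shows the idea on a plain array. It illustrates element-wise clipping under that assumption and is not DL4J's internal code; `clipElementWise` is a name invented for this example.\n",
    "\n",
    "```scala\n",
    "// Hypothetical helper: clip every element to the interval [-threshold, threshold]\n",
    "def clipElementWise(gradient: Array[Double], threshold: Double): Array[Double] =\n",
    "  gradient.map(g => math.max(-threshold, math.min(threshold, g)))\n",
    "\n",
    "val clipped = clipElementWise(Array(0.2, -1.5, 0.5, 3.0), threshold = 0.5)\n",
    "println(clipped.mkString(\", \"))  // 0.2, -0.5, 0.5, 0.5\n",
    "```\n",
    "\n",
    "Small gradients pass through unchanged, while large ones are capped, which helps keep recurrent networks from taking destabilizing update steps."
   ]
  },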
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Training the classifier\n",
    "\n",
    "To train the model, pass the training iterator to the model's `fit()` method. We can use a loop to train the model using a prespecified number of epochs or passes through the training data. \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "autoscroll": "auto"
   },
   "outputs": [],
   "source": [
    "val numEpochs = 1\n",
    "(1 to numEpochs).foreach(_ => model.fit(trainIter) )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Model Evaluation\n",
    "Once training is complete we only need a couple of lines of code to evaluate the model on a test set. Evaluating on data that the model has not seen during training is how we detect overfitting: a model that has overfit has essentially fit the noise in the training data and will score noticeably worse on the test set.\n",
    "\n",
    "The `Evaluation` class has more built-in methods if you need to extract a confusion matrix, and other tools are also available for calculating the Area Under Curve (AUC)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "autoscroll": "auto"
   },
   "outputs": [],
   "source": [
    "val evaluation = model.evaluate(testIter)\n",
    "\n",
    "// print the basic statistics about the trained classifier\n",
    "println(\"Accuracy: \"+evaluation.accuracy())\n",
    "println(\"Precision: \"+evaluation.precision())\n",
    "println(\"Recall: \"+evaluation.recall())"
   ]
  },
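  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The statistics printed above are all derived from a confusion matrix. The sketch below computes accuracy and macro-averaged precision and recall from a small 3-class matrix in plain Scala. The counts are invented for illustration, and weighting every class equally is an assumption about how the multi-class values are summarized; consult the `Evaluation` documentation for the exact averaging DL4J applies.\n",
    "\n",
    "```scala\n",
    "// Made-up confusion matrix: rows are actual classes, columns are predicted classes\n",
    "val confusion = Array(\n",
    "  Array(8, 1, 1),\n",
    "  Array(2, 6, 2),\n",
    "  Array(0, 0, 10)\n",
    ")\n",
    "val numClasses = confusion.length\n",
    "val total = confusion.map(_.sum).sum\n",
    "\n",
    "// Accuracy: correctly classified examples (the diagonal) divided by all examples\n",
    "val accuracy = (0 until numClasses).map(i => confusion(i)(i)).sum.toDouble / total\n",
    "\n",
    "// Per-class precision: true positives divided by the column sum (everything predicted as that class)\n",
    "val precisions = (0 until numClasses).map { c =>\n",
    "  confusion(c)(c).toDouble / confusion.map(row => row(c)).sum\n",
    "}\n",
    "// Per-class recall: true positives divided by the row sum (everything that actually is that class)\n",
    "val recalls = (0 until numClasses).map(c => confusion(c)(c).toDouble / confusion(c).sum)\n",
    "\n",
    "// Macro average: every class contributes equally\n",
    "val macroPrecision = precisions.sum / numClasses\n",
    "val macroRecall = recalls.sum / numClasses\n",
    "\n",
    "println(\"Accuracy: \" + accuracy)\n",
    "println(\"Precision: \" + macroPrecision)\n",
    "println(\"Recall: \" + macroRecall)\n",
    "```\n",
    "\n",
    "Looking at the per-class values in `precisions` and `recalls` is often more informative than the averages, since a single weak class can hide behind a good overall score."
   ]
  },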
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### What's next?\n",
    "\n",
    "- Check out all of our tutorials available [on GitHub](https://github.com/eclipse/deeplearning4j/tree/master/dl4j-examples/tutorials). The notebooks are numbered so that they are easy to follow in order."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Spark 2.0.0 - Scala 2.11",
   "language": "scala",
   "name": "spark2-scala"
  },
  "language_info": {
   "codemirror_mode": "text/x-scala",
   "file_extension": ".scala",
   "mimetype": "text/x-scala",
   "name": "scala",
   "pygments_lexer": "scala",
   "version": "2.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
